# -*- coding: utf-8 -*-
"""
Created on Tue Jul 27 15:04:31 2021

@author: Lenovo
"""

import numpy as np
from scipy.optimize import linprog
import pandas as pd
from os.path import isfile


#=============================================================================
#The Test
#=============================================================================

def PQR_risk_test_find_e( C, P, tol = 2**(-15), check = None ):
    """Binary-search the largest e in [0, 1] at which ``check(C, P, e)`` passes.

    e = 1.0 is tried first; if it passes, 1.0 is returned immediately.
    Otherwise the bracket [low, high] (initially [0, 1]) is halved until its
    width drops below ``tol``. Bisection assumes passing is monotone in e
    (if some e passes, every smaller e passes) -- NOTE(review): this is an
    assumption about ``check``, not something verified here.

    Parameters
    ----------
    C, P : array-like
        Passed through unchanged to ``check``.
    tol : float, optional
        Stop once the bracket width is below this value. Defaults to 2**-15.
        Must be positive.
    check : callable, optional
        Predicate ``check(C, P, e) -> bool``. Defaults to ``PQR_risk_test``.

    Returns
    -------
    float
        1.0 if ``check`` passes at e = 1.0. Otherwise the midpoint of the final
        bracket, which lies within ``tol`` of the pass/fail threshold but is
        itself never evaluated.

    Raises
    ------
    ValueError
        If ``tol`` is not positive (the loop would never terminate).
    """
    if tol <= 0:
        raise ValueError('tol must be positive, got {0}'.format(tol))
    if check is None:
        # Resolved at call time so existing callers keep the original behaviour.
        check = PQR_risk_test
    high_e = 1.0
    low_e = 0.0
    e = 1.0
    #Will binary search until we have narrowed the window enough.
    while high_e - low_e >= tol:
        if check( C, P, e ):
            low_e = e #Our e passes so raise the lower bound.
        else:
            high_e = e #Our e failed so lower the upper bound.
        e = (high_e + low_e) / 2.0 #New e is the midpoint.
    return e


def PQR_risk_test(C,P,e = 1.0):
    """Return whether the choices in ``C`` at prices ``P`` pass the PQR risk test at level ``e``.

    Each row of ``C`` is read as two lotteries, ``(C[n,0], C[n,1])`` and
    ``(C[n,2], C[n,3])``, and ``P[n,:]`` holds the prices for those entries.
    Every distinct lottery found in ``C`` becomes one LP variable.

    For every observation ``n`` and every ordered pair ``(lot1, lot2)`` of
    distinct lotteries, the pair counts as revealed worse when
    ``P[n,:] @ (lot1 + lot2) < e * (P[n,:] @ C[n,:])``. Each such relation
    adds the row

        V[lot1] + V[lot2] - V[chosen1] - V[chosen2] <= -1

    and the test passes if the system is feasible. It is solved by
    ``scipy.optimize.linprog`` with a zero objective and its default
    non-negativity bounds.

    Parameters
    ----------
    C : numpy.ndarray
        Observed choices; columns 0-3 are read (presumably shape ``(N, 4)``).
    P : numpy.ndarray
        Prices, one row per observation, combined with ``C[n,:]`` via ``@``.
    e : float, optional
        Multiplier on the cost of the chosen bundle (default 1.0).

    Returns
    -------
    bool
        True if feasible (including when no relations arise), False if
        ``linprog`` reports infeasibility.

    Raises
    ------
    RuntimeError
        If ``linprog`` returns any status other than 0 (success) or 2 (infeasible).
    """
    N = C.shape[0]
    if N == 0:
        # No observations yield no constraints: trivially feasible.
        return True

    # Distinct lotteries from both slots of every observation. np.unique sorts
    # the rows, which fixes the variable order; rows become hashable tuples.
    lotteries = np.unique( np.concatenate( (C[:,[0,1]], C[:,[2,3]]), axis=0 ), axis = 0 )
    lotteries = [ tuple(lot) for lot in lotteries ]
    # Lottery tuple -> LP variable index.
    lottery_dict = { lot: i for i, lot in enumerate(lotteries) }

    # Every ordered lottery pair with its 4-entry vector. These are the same for
    # all observations, so they are built once rather than once per observation.
    candidates = [ ( ( lottery_dict[lot1], lottery_dict[lot2] ), np.array(lot1 + lot2) )
                   for lot1 in lotteries for lot2 in lotteries ]

    # (chosen_pair, revealed_pair) per constraint, in observation then pair order.
    relations = []
    for n in range(N):
        prices = P[n,:]
        # Scaled cost of the chosen bundle; invariant across candidate pairs.
        budget = e * (prices @ C[n,:])
        chosen = ( lottery_dict[ tuple( C[n,[0,1]] ) ], lottery_dict[ tuple( C[n,[2,3]] ) ] )
        relations.extend( (chosen, pair) for pair, vec in candidates if budget > prices @ vec )

    if not relations:
        # Nothing is revealed worse at this e: the empty system is feasible.
        return True

    # Row r: -1 per chosen lottery, +1 per revealed-worse lottery (a repeated
    # lottery accumulates to -2 / +2).
    A = np.zeros( (len(relations), len(lotteries)) )
    for r, (chosen, pair) in enumerate(relations):
        for var in chosen:
            A[r, var] -= 1.0
        for var in pair:
            A[r, var] += 1.0

    # Zero objective: only feasibility matters. The -1 right-hand side stands in
    # for strict preference; the constraints are homogeneous, so any strict
    # solution can be rescaled to satisfy it.
    sol = linprog( c = np.zeros( len(lotteries) ), A_ub = A, b_ub = -np.ones( len(relations) ), method = 'highs-ds' )

    if sol.status == 0:
        return True
    if sol.status == 2:
        return False
    raise RuntimeError('PQR risk test linprog returned status {0}'.format(sol.status))


#==============================================================================
#Perform the test
#==============================================================================

columns = ['PQR_risk', 'cat']

#=============================================================================
#INPUT FILE
#=============================================================================
df = pd.read_csv( 'ITCR_cdata.csv' )
#Get rid of no risk questions (rows where pz or pw holds the 66666 code).
#.copy() makes df an independent frame, so adding the 'edges' column below
#does not write into a slice of the unfiltered data (SettingWithCopyWarning).
df = df[ (df['pz'] != 66666.) & (df['pw'] != 66666.) ].copy()
ids = sorted( set( df['participantid'] ) )

#Put in edges: 'edges' is 1 when any of x, y, z, w equals zero, else 0.
df['edges'] = 0
zeros_count = (df[['x','y','z','w']] == 0.).sum(1)
df.loc[ zeros_count >= 1, 'edges' ] = 1

#=============================================================================
#RESULT FILE
#=============================================================================

#-------------------
#manage results file
#-------------------
results_file = 'PQR_risk_test_by_participants.xlsx'
if isfile( results_file ):
    #Resume an earlier run: participants whose row has no missing values are skipped below.
    results_df = pd.read_excel( results_file, sheet_name = 0, index_col = 0 )
else:
    #Make a new results df. 'cat' holds strings, so make it object dtype up front
    #instead of relying on pandas to upcast a float (NaN) column on assignment.
    results_df = pd.DataFrame( np.nan, index = ids, columns = columns )
    results_df['cat'] = results_df['cat'].astype(object)
    results_df.index.name = 'participantid'

try:
    for pid in ids:
        df_cur = df[ df['participantid'] == pid ]

        #Edge category: 'edge41' when exactly 41 of this participant's rows are
        #flagged as edges, otherwise 'edge0'.
        #NOTE(review): 41 presumably equals the number of risk questions per participant; confirm.
        results_df.loc[pid, 'cat'] = 'edge41' if df_cur['edges'].sum() == 41 else 'edge0'

        #Only compute participants with a missing value (in practice, PQR_risk).
        if pd.isnull( results_df.loc[pid,:] ).any():
            print('working on pid {0}'.format(pid))
            P = df_cur.loc[ :, 'px':'pw' ].to_numpy()
            C = df_cur.loc[ :, 'x':'w' ].to_numpy()
            results_df.loc[pid,'PQR_risk'] = PQR_risk_test_find_e(C,P)
finally:
    #Save even if the loop is interrupted, so finished participants are kept
    #and a rerun only processes the rest.
    results_df.to_excel(results_file)

